# coding=utf-8
from dataclasses import dataclass
from threading import Lock
from typing import Iterable

import requests

import settings


@dataclass
class StockMeasure:
    """One real-time quote snapshot for a single stock."""

    code: str  # symbol, e.g. "sz002156" as produced by get_stock_data
    name: str  # display name reported by the quote API
    yesterday_price: float  # previous trading day's closing price
    real_price: float  # latest (real-time) price
    date: str  # quote date string as returned by the API
    time: str  # quote time string as returned by the API

    @property
    def rise(self):
        """Percent change from ``yesterday_price`` to ``real_price``, rounded to 2 decimals.

        Returns 0.0 when ``yesterday_price`` is 0 (there is no reference price),
        rather than raising ZeroDivisionError.
        """
        if not self.yesterday_price:
            return 0.0
        return round((self.real_price - self.yesterday_price) / self.yesterday_price * 100, 2)


class RiseCache:
    """Per-stock rolling window of recent ``rise`` percentages.

    Each stock code maps to a list of its most recent ``rise`` values (oldest
    first). ``amplitude`` reports the spread (max - min) of that window. All
    access to the underlying dict goes through a single lock.
    """

    def __init__(self):
        # code -> list of recent rise percentages, oldest first
        self._data = {}
        # guards every read and write of ``_data``
        self._lock = Lock()

    def __str__(self):
        return f"{self._data}"

    def put(self, item: "StockMeasure", length):
        """Record ``item.rise`` under ``item.code``, keeping at most ``length`` values.

        The oldest values are dropped once the window is full. The window never
        drops below one value, so a non-positive ``length`` keeps just the latest
        rise instead of raising. When ``length`` is smaller than the current
        history (e.g. the configured interval was reduced), the excess old values
        are trimmed.

        :param item: measurement providing ``code`` and ``rise``.
        :param length: maximum number of values to retain for this code.
        """
        # Evaluate before touching the cache so a failing ``rise`` cannot leave
        # an empty history entry behind.
        rise = item.rise
        window = max(int(length), 1)
        with self._lock:
            history = self._data.setdefault(item.code, [])
            history.append(rise)
            # Drop everything but the newest ``window`` values; this also
            # handles a window that shrank since the previous put.
            del history[:-window]

    def batch_put(self, items: "list[StockMeasure]"):
        """Record every measurement, sizing each code's window from ``settings``.

        The window length for a code is its amplitude-alert interval floor-divided
        by ``settings.COLLECT_INTERVAL`` (presumably the number of samples
        collected within one alert interval).
        """
        stocks = settings.STOCKS
        for item in items:
            interval = stocks.amplitude_alert_interval(item.code)
            self.put(item, interval // settings.COLLECT_INTERVAL)

    def amplitude(self, code) -> float:
        """Return max minus min of the cached rises for ``code``; 0 if nothing is cached."""
        with self._lock:
            history_rises = self._data.get(code)
            if not history_rises:
                return 0
            return max(history_rises) - min(history_rises)


def _to_sina_symbol(code: str) -> str:
    """Prefix a bare code with its Sina exchange tag ("6..." -> "sh", "0..."/"3..." -> "sz")."""
    if code.startswith("6"):
        return "sh" + code
    if code.startswith(("0", "3")):
        return "sz" + code
    # Anything else (e.g. an already-prefixed "sz002687") is passed through.
    return code


def _parse_quote_line(line: str) -> "StockMeasure | None":
    """Parse one ``var hq_str_<symbol>="f0,f1,...";`` response line.

    Returns ``None`` for lines that carry no usable quote, such as the empty
    ``""`` payload returned for an unknown symbol.
    """
    lhs, sep, rhs = line.partition("=")
    if not sep:
        return None
    # Drop the surrounding quotes and the trailing ';' before splitting.
    fields = rhs.strip().rstrip(";").strip('"').split(",")
    if len(fields) < 32:
        return None
    code = lhs.strip().split("_")[-1]
    return StockMeasure(
        code,
        fields[0],  # name
        float(fields[2]),  # yesterday's closing price
        float(fields[3]),  # current price
        fields[30],  # date
        fields[31],  # time
    )


def get_stock_data(stock_codes: Iterable, *, timeout: float = 10.0) -> "list[StockMeasure]":
    """Fetch real-time quotes for ``stock_codes`` from the Sina quote API.

    Bare codes are prefixed with their exchange tag ("6..." -> "sh",
    "0..."/"3..." -> "sz"); other codes are sent unchanged. Duplicates are
    queried once.

    :param stock_codes: stock codes, bare or already exchange-prefixed.
    :param timeout: seconds to wait for the HTTP response before giving up.
    :return: one StockMeasure per symbol that returned a quote; symbols whose
        response line is empty or too short (e.g. unknown codes) are skipped.
    :raises requests.HTTPError: if the API answers with an error status.
    """
    # sina stock api:
    # https://www.cnblogs.com/zeroes/p/sina_stock_api.html
    symbols = {_to_sina_symbol(code) for code in stock_codes}
    if not symbols:
        # Nothing to query; skip the network round-trip.
        return []

    query_code_list = ",".join(sorted(symbols))
    url = f"https://hq.sinajs.cn/list={query_code_list}"
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36",
        "Referer": "https://stock.sina.com.cn/index/stock/index.shtml"
    }

    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()

    parsed = (_parse_quote_line(line) for line in response.text.splitlines())
    return [measure for measure in parsed if measure is not None]


if __name__ == "__main__":
    # Manual smoke run against the live Sina endpoint. Codes may be bare
    # ("002156") or already exchange-prefixed (e.g. "sz002687", "sz002371").
    sample_codes = ["002156"]
    print(get_stock_data(sample_codes))
